# Header document: the overall name for this multi-task file. The documents
# that follow define the `train` task and then the `eval` task.
name: bert_qa

---

# Task `train`: fine-tunes `bert-base-uncased` on SQuAD using one V100 GPU.
name: train

resources:
    accelerators: 'V100:1'

# This example expects a local checkout of Hugging Face transformers at
# `~/transformers`. To create it, run:
#   git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1
workdir: '~/transformers'

file_mounts:
    # Bucket mounted at /checkpoint; the run command writes its output there.
    /checkpoint:
        mode: MOUNT
        name: test-bert-train-eval  # NOTE: Fill in your bucket name

# Installs the working-directory checkout in editable mode, then the
# question-answering example's requirements (torch pinned to the cu113 build)
# and wandb.
setup: |
    pip install -e .
    cd examples/pytorch/question-answering/
    pip install \
        -r requirements.txt \
        torch==1.12.1+cu113 \
        --extra-index-url https://download.pytorch.org/whl/cu113
    pip install wandb

# Fine-tunes BERT on SQuAD for one epoch and writes checkpoints (every 1000
# steps, keeping at most 10) under /checkpoint/bert_qa/$SKYPILOT_TASK_ID.
run: |
    # Stop at the first failing command. Without this, the trailing `echo`
    # is the last command executed, so the script would exit with status 0
    # and print "Model saved" even when `cd` or the training run failed.
    set -e
    cd examples/pytorch/question-answering/
    python run_qa.py \
        --model_name_or_path bert-base-uncased \
        --dataset_name squad \
        --do_train \
        --per_device_train_batch_size 12 \
        --learning_rate 3e-5 \
        --num_train_epochs 1 \
        --max_seq_length 384 \
        --doc_stride 128 \
        --report_to wandb \
        --run_name "$SKYPILOT_TASK_ID" \
        --output_dir "/checkpoint/bert_qa/$SKYPILOT_TASK_ID" \
        --save_total_limit 10 \
        --save_steps 1000
    echo "Model saved to /checkpoint/bert_qa/$SKYPILOT_TASK_ID"

envs:
    # NOTE: Fill in your wandb key (the run command reports to wandb).
    WANDB_API_KEY: null

---

# Task `eval`: evaluates a previously saved model on SQuAD using one T4 GPU.
name: eval

resources:
    accelerators: 'T4:1'

# Same local transformers checkout as used by the `train` task.
workdir: '~/transformers'

file_mounts:
    # Bucket mounted at /checkpoint; the run command loads the model from it.
    /checkpoint:
        mode: MOUNT
        name: test-bert-train-eval  # NOTE: Fill in your bucket name

# Installs the working-directory checkout in editable mode, then the
# question-answering example's requirements (torch pinned to the cu113 build)
# and wandb.
setup: |
    pip install -e .
    cd examples/pytorch/question-answering/
    pip install \
        -r requirements.txt \
        torch==1.12.1+cu113 \
        --extra-index-url https://download.pytorch.org/whl/cu113
    pip install wandb

# Evaluates the model stored under /checkpoint/bert_qa/<id>, where <id> is the
# first line of $SKYPILOT_TASK_IDS, and writes output to that same directory.
# NOTE(review): the training-oriented flags below (train batch size, learning
# rate, epochs, save_*) are kept unchanged; presumably they have no effect
# when only --do_eval is given -- confirm before removing them.
run: |
    TRAIN_TASK_ID=$(echo "$SKYPILOT_TASK_IDS" | head -n 1)
    MODEL_DIR=/checkpoint/bert_qa/$TRAIN_TASK_ID
    echo Load model from $MODEL_DIR
    cd examples/pytorch/question-answering/
    python run_qa.py \
        --model_name_or_path $MODEL_DIR \
        --dataset_name squad \
        --do_eval \
        --per_device_train_batch_size 12 \
        --learning_rate 3e-5 \
        --num_train_epochs 50 \
        --max_seq_length 384 \
        --doc_stride 128 \
        --report_to wandb \
        --run_name $SKYPILOT_TASK_ID \
        --output_dir $MODEL_DIR \
        --save_total_limit 10 \
        --save_steps 1000

envs:
    # NOTE: Fill in your wandb key (the run command reports to wandb).
    WANDB_API_KEY: null
